from collections.abc import Sequence
from typing import Optional

import hail as hl
from hail.expr.expressions import Expression
from hail.expr.expressions.typed_expressions import (
    ArrayExpression,
    CallExpression,
    LocusExpression,
    NumericExpression,
    StructExpression,
)
from hail.genetics.allele_type import AlleleType
from hail.methods.misc import require_first_key_field_locus
from hail.methods.qc import _qc_allele_type
from hail.table import Table
from hail.typecheck import nullable, sequenceof, typecheck
from hail.utils.java import Env
from hail.utils.misc import divide_null
from hail.vds.variant_dataset import VariantDataset


@typecheck(global_gt=Expression, alleles=ArrayExpression)
def vmt_sample_qc_variant_annotations(
    *,
    global_gt: 'Expression',
    alleles: 'ArrayExpression',
) -> tuple['Expression', 'Expression']:
    """Build the per-variant annotations consumed by :func:`.vmt_sample_qc`.

    Two expressions are produced: the allele count (AC) array taken from
    ``hl.agg.call_stats``, and an array containing an integer allele-type code
    for every alternate allele (computed by ``_qc_allele_type`` against the
    first allele).

    Parameters
    ----------
    global_gt : :class:`.Expression`
        Call expression of the global GT of a variants matrix table, usually
        generated by :func:`.lgt_to_gt`.
    alleles : :class:`.ArrayExpression`
        Array expression of the alleles of a variants matrix table
        (generally ``vds.variant_data.alleles``). The first element is used as
        the reference allele; the remaining elements are the alternates.

    Returns
    -------
    :class:`tuple`
        ``(AC, allele_types)`` -- the allele-count expression first and the
        alternate-allele type array expression second.
    """
    ref_allele = alleles[0]
    allele_counts = hl.agg.call_stats(global_gt, alleles).AC
    # One type code per alternate allele, so index k here describes allele k + 1.
    alt_allele_types = alleles[1:].map(lambda alt_allele: _qc_allele_type(ref_allele, alt_allele))
    return allele_counts, alt_allele_types


@typecheck(
    global_gt=Expression,
    gq=Expression,
    variant_ac=ArrayExpression,
    variant_atypes=ArrayExpression,
    dp=nullable(Expression),
    gq_bins=sequenceof(int),
    dp_bins=sequenceof(int),
)
def vmt_sample_qc(
    *,
    global_gt: 'CallExpression',
    gq: 'Expression',
    variant_ac: 'ArrayExpression',
    variant_atypes: 'ArrayExpression',
    dp: Optional['Expression'] = None,
    gq_bins: 'Sequence[int]' = (0, 20, 60),
    dp_bins: 'Sequence[int]' = (0, 1, 10, 20, 30),
) -> 'StructExpression':
    """Computes sample quality metrics from variant data of a VDS

    Parameters
    ----------
    global_gt : :class:`.CallExpression`
        Global GT of a variants matrix table or subset thereof (ex. ``hl.agg.group_by``).
    gq : :class:`.Expression`
        GQ of a variants matrix table.
    variant_ac : :class:`.ArrayExpression`
        Allele counts of the genotypes of a variants matrix table. This can
        be generated by ``hl.agg.call_stats`` or alternatively
        :func:`.vmt_sample_qc_variant_annotations` (which calls ``call_stats``
        internally).
    variant_atypes : :class:`.ArrayExpression`
        Allele types of the alternate alleles of a variants matrix table. This
        must be generated with :func:`.vmt_sample_qc_variant_annotations` in
        order to return correct results.
    dp : :class:`.Expression` or :obj:`NoneType`
        DP of a variants matrix table (or ``None``).
    gq_bins : :class:`tuple` of :obj:`int`
        Tuple containing cutoffs for genotype quality (GQ) scores.
    dp_bins : :class:`tuple` of :obj:`int`
        Tuple containing cutoffs for depth (DP) scores.

    Returns
    -------
    :class:`.StructExpression`
        A struct expression of type::

            struct{
                bases_over_gq_threshold: tuple(int64 * len(gq_bins)),
                bases_over_dp_threshold: tuple(int64 * len(dp_bins)),  # present if dp is not None
                n_het: int64,
                n_hom_var: int64,
                n_non_ref: int64,
                n_singleton: int64,
                n_singleton_ti: int64,
                n_singleton_tv: int64,
                n_snp: int64,
                n_insertion: int64,
                n_deletion: int64,
                n_transition: int64,
                n_transversion: int64,
                n_star: int64,
                r_ti_tv: float64,
                r_ti_tv_singleton: float64,
                r_het_hom_var: float64,
                r_insertion_deletion: float64,
            }

        The ``r_*`` ratios are computed with ``divide_null``.
    """

    def _count_singleton_alleles(allele_type=None):
        # Aggregated count of called non-reference alleles whose AC (from
        # ``variant_ac``) is exactly 1, optionally restricted to one allele type.
        def is_counted(gti):
            keep = (gti != 0) & (variant_ac[gti] == 1)
            if allele_type is not None:
                # ``variant_atypes`` only describes alternate alleles, hence the -1 offset.
                keep = keep & (variant_atypes[gti - 1] == allele_type)
            return keep

        return hl.agg.sum(
            hl.rbind(
                global_gt,
                lambda gt: hl.sum(hl.range(0, gt.ploidy).map(lambda i: hl.rbind(gt[i], is_counted))),
            )
        )

    bound_exprs = {}

    bound_exprs['n_het'] = hl.agg.count_where(global_gt.is_het())
    bound_exprs['n_hom_var'] = hl.agg.count_where(global_gt.is_hom_var())
    bound_exprs['n_singleton'] = _count_singleton_alleles()
    bound_exprs['n_singleton_ti'] = _count_singleton_alleles(AlleleType.TRANSITION)
    bound_exprs['n_singleton_tv'] = _count_singleton_alleles(AlleleType.TRANSVERSION)

    # allele_type_counts[k] = number of called non-reference alleles whose type code is k.
    bound_exprs['allele_type_counts'] = hl.agg.explode(
        lambda allele_type: hl.tuple(hl.agg.count_where(allele_type == i) for i in range(len(AlleleType))),
        (
            hl.range(0, global_gt.ploidy)
            .map(lambda i: global_gt[i])
            .filter(lambda allele_idx: allele_idx > 0)
            .map(lambda allele_idx: variant_atypes[allele_idx - 1])
        ),
    )

    dp_exprs = {}
    if dp is not None:
        dp_exprs['bases_over_dp_threshold'] = hl.tuple(hl.agg.count_where(dp >= x) for x in dp_bins)

    gq_dp_exprs = {'bases_over_gq_threshold': hl.tuple(hl.agg.count_where(gq >= x) for x in gq_bins), **dp_exprs}

    # Bind the aggregated values once, derive the count fields, then derive the ratios
    # from those counts so no aggregation is duplicated.
    return hl.rbind(
        hl.struct(**bound_exprs),
        lambda x: hl.rbind(
            hl.struct(**{
                **gq_dp_exprs,
                'n_het': x.n_het,
                'n_hom_var': x.n_hom_var,
                'n_non_ref': x.n_het + x.n_hom_var,
                'n_singleton': x.n_singleton,
                'n_singleton_ti': x.n_singleton_ti,
                'n_singleton_tv': x.n_singleton_tv,
                'n_snp': x.allele_type_counts[AlleleType.TRANSITION] + x.allele_type_counts[AlleleType.TRANSVERSION],
                'n_insertion': x.allele_type_counts[AlleleType.INSERTION],
                'n_deletion': x.allele_type_counts[AlleleType.DELETION],
                'n_transition': x.allele_type_counts[AlleleType.TRANSITION],
                'n_transversion': x.allele_type_counts[AlleleType.TRANSVERSION],
                'n_star': x.allele_type_counts[AlleleType.STAR],
            }),
            lambda s: s.annotate(
                r_ti_tv=divide_null(hl.float64(s.n_transition), s.n_transversion),
                r_ti_tv_singleton=divide_null(hl.float64(s.n_singleton_ti), s.n_singleton_tv),
                r_het_hom_var=divide_null(hl.float64(s.n_het), s.n_hom_var),
                r_insertion_deletion=divide_null(hl.float64(s.n_insertion), s.n_deletion),
            ),
        ),
    )


@typecheck(
    locus=LocusExpression,
    gq=NumericExpression,
    end=NumericExpression,
    dp=nullable(Expression),
    gq_bins=sequenceof(int),
    dp_bins=sequenceof(int),
)
def rmt_sample_qc(
    *,
    locus: 'LocusExpression',
    end: 'NumericExpression',
    gq: 'NumericExpression',
    dp: Optional['Expression'] = None,
    gq_bins: 'Sequence[int]' = (0, 20, 60),
    dp_bins: 'Sequence[int]' = (0, 1, 10, 20, 30),
) -> 'StructExpression':
    """Computes sample quality metrics from reference data of a VDS

    Each reference row contributes ``1 + end - locus.position`` bases to every
    bin whose cutoff its GQ (or depth) meets or exceeds.

    Parameters
    ----------
    locus : :class:`.LocusExpression`
        Locus of a reference matrix table.
    end : :class:`.NumericExpression`
        END of a reference matrix table.
    gq : :class:`.NumericExpression`
        GQ of a reference matrix table.
    dp : :class:`.Expression` or :obj:`NoneType`
        Depth field of a reference matrix table (or ``None``).
    gq_bins : :class:`tuple` of :obj:`int`
        Tuple containing cutoffs for genotype quality (GQ) scores.
    dp_bins : :class:`tuple` of :obj:`int`
        Tuple containing cutoffs for depth (DP) scores.

    Returns
    -------
    :class:`.StructExpression`
        A struct expression of type::

            struct{
                bases_over_gq_threshold: tuple(int64 * len(gq_bins)),
                bases_over_dp_threshold: tuple(int64 * len(dp_bins)),  # present if dp is not None
            }

    """
    # Number of bases covered by one reference row (both endpoints inclusive).
    n_bases = 1 + end - locus.position

    ref_dp_expr = {}
    if dp is not None:
        ref_dp_expr['bases_over_dp_threshold'] = hl.tuple(
            hl.agg.filter(dp >= x, hl.agg.sum(n_bases)) for x in dp_bins
        )
    return hl.struct(
        bases_over_gq_threshold=hl.tuple(hl.agg.filter(gq >= x, hl.agg.sum(n_bases)) for x in gq_bins),
        **ref_dp_expr,
    )


def combine_sample_qc(
    rmt_sample_qc: Expression,
    vmt_sample_qc: Expression,
) -> Expression:
    """Combine reference and variants sample quality results

    The threshold tuples of both arguments are summed element-wise.

    Parameters
    ----------
    rmt_sample_qc : :class:`.Expression`
        A struct expression produced by :func:`.rmt_sample_qc`
    vmt_sample_qc : :class:`.Expression`
        A struct expression produced by :func:`.vmt_sample_qc`

    Returns
    -------
    :class:`.StructExpression`
        A struct expression of type::

            struct{
                bases_over_gq_threshold:
                    tuple(int64 * len(rmt_sample_qc.bases_over_gq_threshold)),
                bases_over_dp_threshold:  # present if dp was present for qc stats generation
                    tuple(int64 * len(rmt_sample_qc.bases_over_dp_threshold)),
            }

    Raises
    ------
    ValueError
        If ``bases_over_gq_threshold`` is missing from either argument, if
        ``bases_over_dp_threshold`` is present in only one of them, or if the
        number of GQ or DP bins differs between them.

    Note
    ----
    It is the responsibility of the caller of this function to make sure that
    the ``gq_bins`` and ``dp_bins`` that are used for the generation of both of
    the arguments to this function are the same. Incorrect results will occur
    if the bins are not the same. This function checks the length of the bins
    used, but cannot check the bin values themselves.
    """
    # NOTE: the parameter names shadow the module-level functions of the same
    # name inside this body; they are kept for backward compatibility.
    if 'bases_over_gq_threshold' not in rmt_sample_qc:
        raise ValueError("Expect 'bases_over_gq_threshold' field in 'rmt_sample_qc' expression")
    if 'bases_over_gq_threshold' not in vmt_sample_qc:
        raise ValueError("Expect 'bases_over_gq_threshold' field in 'vmt_sample_qc' expression")

    has_dp = 'bases_over_dp_threshold' in rmt_sample_qc
    if has_dp != ('bases_over_dp_threshold' in vmt_sample_qc):
        raise ValueError(
            "Expect 'bases_over_dp_threshold' field in both or neither of 'rmt_sample_qc' and 'vmt_sample_qc'"
        )
    if len(rmt_sample_qc.bases_over_gq_threshold) != len(vmt_sample_qc.bases_over_gq_threshold):
        raise ValueError("Expect same number of GQ bins for both variant and reference qc results")
    if has_dp and len(rmt_sample_qc.bases_over_dp_threshold) != len(vmt_sample_qc.bases_over_dp_threshold):
        raise ValueError("Expect same number of DP bins for both variant and reference qc results")

    joined_dp_expr = {}
    if has_dp:
        joined_dp_expr['bases_over_dp_threshold'] = hl.tuple(
            x + y for x, y in zip(vmt_sample_qc.bases_over_dp_threshold, rmt_sample_qc.bases_over_dp_threshold)
        )

    return hl.struct(
        bases_over_gq_threshold=hl.tuple(
            x + y for x, y in zip(vmt_sample_qc.bases_over_gq_threshold, rmt_sample_qc.bases_over_gq_threshold)
        ),
        **joined_dp_expr,
    )


@typecheck(vds=VariantDataset, gq_bins=sequenceof(int), dp_bins=sequenceof(int), dp_field=nullable(str))
def sample_qc(
    vds: 'VariantDataset',
    *,
    gq_bins: 'Sequence[int]' = (0, 20, 60),
    dp_bins: 'Sequence[int]' = (0, 1, 10, 20, 30),
    dp_field: Optional[str] = None,
) -> 'Table':
    """Compute sample quality metrics about a :class:`.VariantDataset`.

    If the `dp_field` parameter is not specified, the ``DP`` is used for depth
    if present. If no ``DP`` field is present, the ``MIN_DP`` field is used. If no ``DP``
    or ``MIN_DP`` field is present, no depth statistics will be calculated.

    The depth field selection above applies to the reference data. Depth of
    the variant data is always read from its ``DP`` entry field.

    Parameters
    ----------
    vds : :class:`.VariantDataset`
        Dataset in VariantDataset representation.
    gq_bins : :class:`tuple` of :obj:`int`
        Tuple containing cutoffs for genotype quality (GQ) scores.
    dp_bins : :class:`tuple` of :obj:`int`
        Tuple containing cutoffs for depth (DP) scores.
    dp_field : :obj:`str`
        Name of depth field. If not supplied, DP or MIN_DP will be used, in that order.

    Returns
    -------
    :class:`.Table`
        Hail Table of results, keyed by sample. The bins used are stored in the
        ``gq_bins`` global, and in ``dp_bins`` when depth statistics were computed.

    Raises
    ------
    ValueError
        If a reference depth field is in use but the variant data has no
        ``DP`` entry field.
    """

    require_first_key_field_locus(vds.reference_data, 'sample_qc')
    require_first_key_field_locus(vds.variant_data, 'sample_qc')

    if dp_field is not None:
        ref_dp_field_to_use = dp_field
    elif 'DP' in vds.reference_data.entry:
        ref_dp_field_to_use = 'DP'
    elif 'MIN_DP' in vds.reference_data.entry:
        ref_dp_field_to_use = 'MIN_DP'
    else:
        ref_dp_field_to_use = None

    vmt = vds.variant_data
    # Reference and variant depth bins are summed later, so both sides must have
    # depth. Fail here with a clear message instead of deep inside combine_sample_qc.
    if ref_dp_field_to_use is not None and 'DP' not in vmt.entry:
        raise ValueError(
            f"sample_qc: reference data depth field {ref_dp_field_to_use!r} is in use, "
            "but variant data has no 'DP' entry field"
        )
    if 'GT' not in vmt.entry:
        vmt = vmt.annotate_entries(GT=hl.vds.lgt_to_gt(vmt.LGT, vmt.LA))
    allele_count, atypes = vmt_sample_qc_variant_annotations(global_gt=vmt.GT, alleles=vmt.alleles)
    # Temporary row fields with unique names so they cannot collide with user fields.
    variant_ac = Env.get_uid()
    variant_atypes = Env.get_uid()
    vmt = vmt.annotate_rows(**{variant_ac: allele_count, variant_atypes: atypes})
    vmt_dp = vmt['DP'] if ref_dp_field_to_use is not None else None
    variant_results = vmt.select_cols(
        **vmt_sample_qc(
            global_gt=vmt.GT,
            gq=vmt.GQ,
            variant_ac=vmt[variant_ac],
            variant_atypes=vmt[variant_atypes],
            dp=vmt_dp,
            gq_bins=gq_bins,
            dp_bins=dp_bins,
        )
    ).cols()

    rmt = vds.reference_data
    rmt_dp = rmt[ref_dp_field_to_use] if ref_dp_field_to_use is not None else None
    reference_results = rmt.select_cols(
        **rmt_sample_qc(
            locus=rmt.locus,
            gq=rmt.GQ,
            end=rmt.END,
            dp=rmt_dp,
            gq_bins=gq_bins,
            dp_bins=dp_bins,
        )
    ).cols()

    joined = reference_results[variant_results.key]
    dp_bins_field = {}
    if ref_dp_field_to_use is not None:
        dp_bins_field['dp_bins'] = hl.tuple(dp_bins)
    # Replace the variant-only threshold tuples with reference + variant totals.
    joined_results = variant_results.transmute(**combine_sample_qc(joined, variant_results.row))
    joined_results = joined_results.annotate_globals(gq_bins=hl.tuple(gq_bins), **dp_bins_field)
    return joined_results
